//
// Created by jiash on 28/12/2021.
//

#include <stdint.h>
#include <netinet/in.h>
#include <errno.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <assert.h>

#include "nbd_cmd.h"
#include "nbd.h"
#include "cliserv.h"
#include "nbdsrv.h"
#include "exp_async_io.h"

/*
 * Translate a local errno value into the error code carried in an NBD reply.
 *
 * Only EPERM, EIO, ENOMEM, EINVAL and the "out of space" family
 * (EFBIG, ENOSPC and, where defined, EDQUOT -> 28) are passed through;
 * anything else is reported as EINVAL (22).
 *
 * Returns the code already converted to network byte order.
 */
static uint32_t nbd_errno(int errcode)
{
    uint32_t wire_code;

    switch (errcode) {
        case EPERM:
            wire_code = 1;
            break;
        case EIO:
            wire_code = 5;
            break;
        case ENOMEM:
            wire_code = 12;
            break;
        case EFBIG:
        case ENOSPC:
#ifdef EDQUOT
        case EDQUOT:
#endif
            wire_code = 28; /* ENOSPC */
            break;
        case EINVAL:
        default:
            wire_code = 22; /* EINVAL */
            break;
    }
    return htonl(wire_code);
}

struct expand_sqe g_sqe;

/*
 * Decode a wire-format NBD request header into the file-scope g_sqe.
 *
 * type, from and len are converted from network byte order. The type field
 * is masked with NBD_CMD_MASK_COMMAND to obtain the opcode. cid is copied
 * through unconverted. user_data is reset to 0 (no buffer attached).
 *
 * Every call overwrites and returns the same g_sqe, so the returned pointer
 * only describes the most recent request.
 */
struct expand_sqe *sqe_create(struct nbd_request *req)
{
    struct expand_sqe *entry = &g_sqe;
    uint32_t host_type = ntohl(req->type);

    entry->opcode = host_type & NBD_CMD_MASK_COMMAND;
    entry->offset = ntohll(req->from);
    entry->len = ntohl(req->len);
    entry->cid = req->cid;
    entry->user_data = 0;

    return entry;
}

/*
 * Initialise an NBD reply header for the request described by sqe:
 * reply magic in network byte order, error cleared, and the request's
 * cid copied so the reply carries the same handle.
 */
static void setup_reply(struct nbd_reply *rep, struct expand_sqe *sqe)
{
    const uint32_t magic = htonl(NBD_REPLY_MAGIC);

    rep->magic = magic;
    rep->error = 0;
    rep->cid = sqe->cid;
}

/*
 * Serve NBD_CMD_READ.
 *
 * Allocates sqe->len bytes and attaches the buffer to sqe->user_data before
 * calling the backend's submit_io. Then, holding client->lock, it sends the
 * reply header, followed by the buffer contents only when no error was
 * recorded.
 *
 * The buffer is freed and sqe->user_data reset to 0 before returning, so the
 * caller never sees a dangling buffer.
 *
 * NOTE(review): err() (from cliserv.h) is assumed not to return on allocation
 * failure; confirm.
 */
static void handle_read(CLIENT *client, struct expand_sqe *sqe)
{
    struct exp_async_io *backend = &client->exp_io;
    struct nbd_reply reply;
    void *data = malloc(sqe->len);

    if (data == NULL) {
        err("Could not allocate memory for request");
    }

    setup_reply(&reply, sqe);

    /* The backend receives the destination buffer through user_data. */
    sqe->user_data = (uint64_t) data;
    if (backend->submit_io(backend, sqe)) {
        reply.error = nbd_errno(errno);
    }

    pthread_mutex_lock(&client->lock);
    writeit(client->net, &reply, sizeof reply);
    if (reply.error == 0) {
        writeit(client->net, data, sqe->len);
    }
    pthread_mutex_unlock(&client->lock);

    free(data);
    sqe->user_data = 0;
}

/*
 * Serve NBD_CMD_WRITE by handing sqe to the backend's submit_io and sending a
 * header-only reply under client->lock.
 *
 * On failure the reply error is derived from errno. This function neither
 * allocates nor frees any buffer attached to sqe->user_data.
 *
 * NOTE(review): the FUA command flag is not honoured here (sqe_create masks
 * flags off the type field).
 */
static void handle_write(CLIENT *client, struct expand_sqe *sqe)
{
    struct nbd_reply reply;
    struct exp_async_io *backend;

    backend = &client->exp_io;
    setup_reply(&reply, sqe);

    if (backend->submit_io(backend, sqe)) {
        reply.error = nbd_errno(errno);
    }

    pthread_mutex_lock(&client->lock);
    writeit(client->net, &reply, sizeof(reply));
    pthread_mutex_unlock(&client->lock);
}

/*
 * Serve NBD_CMD_FLUSH by calling the backend's flush hook and sending a
 * header-only reply under client->lock. A non-zero return from flush is
 * reported using the current errno.
 */
static void handle_flush(CLIENT *client, struct expand_sqe *sqe)
{
    struct nbd_reply reply;
    struct exp_async_io *backend;

    backend = &client->exp_io;
    setup_reply(&reply, sqe);

    if (backend->flush(backend)) {
        reply.error = nbd_errno(errno);
    }

    pthread_mutex_lock(&client->lock);
    writeit(client->net, &reply, sizeof(reply));
    pthread_mutex_unlock(&client->lock);
}

/*
 * Serve NBD_CMD_TRIM by handing sqe to the backend's submit_io and sending a
 * header-only reply under client->lock. On failure the reply error is derived
 * from errno.
 */
static void handle_trim(CLIENT *client, struct expand_sqe *sqe)
{
    struct nbd_reply reply;
    struct exp_async_io *backend;

    backend = &client->exp_io;
    setup_reply(&reply, sqe);

    if (backend->submit_io(backend, sqe)) {
        reply.error = nbd_errno(errno);
    }

    pthread_mutex_lock(&client->lock);
    writeit(client->net, &reply, sizeof(reply));
    pthread_mutex_unlock(&client->lock);
}

/*
 * Return true when the byte range [sqe->offset, sqe->offset + sqe->len)
 * does not lie entirely inside the export (client->exportsize bytes).
 *
 * All three values are compared as uint64_t, and the end check is written as
 * `len > size - offset` (evaluated only once offset <= size), so it cannot
 * wrap. The previous `offset + len > size` form could wrap. It could also let
 * a huge wire offset through if offset is stored in a signed type; the field
 * types are declared outside this file.
 */
static bool bad_range(CLIENT *client, struct expand_sqe *sqe)
{
    const uint64_t size = (uint64_t) client->exportsize;
    const uint64_t offset = (uint64_t) sqe->offset;
    const uint64_t len = (uint64_t) sqe->len;

    return offset > size || len > size - offset;
}

/*
 * Validate and dispatch one decoded request to its handler.
 *
 * READ, WRITE and TRIM are range-checked with bad_range() first:
 *   - an out-of-range WRITE is answered with ENOSPC;
 *   - an out-of-range READ or TRIM, or an unknown opcode, is answered with
 *     EINVAL.
 * Error replies are header-only and sent under client->lock.
 *
 * Any buffer still attached to sqe->user_data when dispatch finishes (e.g. a
 * WRITE payload) is freed here, and user_data is cleared.
 */
static void handle_request(CLIENT *client, struct expand_sqe *sqe)
{
    /* Named errcode so it does not shadow the err() reporting function. */
    int errcode = EINVAL;

    switch (sqe->opcode) {
        case NBD_CMD_READ:
            if (bad_range(client, sqe)) {
                goto error;
            }
            handle_read(client, sqe);
            break;
        case NBD_CMD_WRITE:
            if (bad_range(client, sqe)) {
                errcode = ENOSPC;
                goto error;
            }
            handle_write(client, sqe);
            break;
        case NBD_CMD_FLUSH:
            handle_flush(client, sqe);
            break;
        case NBD_CMD_TRIM:
            if (bad_range(client, sqe)) {
                goto error;
            }
            handle_trim(client, sqe);
            break;
        default:
            /* Log the decoded opcode; the request is still answered with EINVAL. */
            msg(LOG_ERR, "E: received unknown command %d of type, ignoring", (int) sqe->opcode);
            goto error;
    }
    goto end;

error:
    {
        struct nbd_reply rep;
        setup_reply(&rep, sqe);
        rep.error = nbd_errno(errcode);
        pthread_mutex_lock(&(client->lock));
        writeit(client->net, &rep, sizeof rep);
        pthread_mutex_unlock(&(client->lock));
    }
end:
    /* free(NULL) is a no-op, so a request without a buffer needs no guard. */
    free((void *) (uintptr_t) sqe->user_data);
    sqe->user_data = 0;
}

/*
 * Read and serve NBD requests from client->net until disconnect or failure.
 *
 * For each request header it:
 *   - checks the request magic;
 *   - decodes the header with sqe_create();
 *   - for NBD_CMD_WRITE, reads the sqe->len-byte payload into a freshly
 *     allocated buffer attached to sqe->user_data (handle_request() frees it);
 *   - dispatches via handle_request().
 * Requests are handled inline on the calling thread.
 *
 * Returns 0 after NBD_CMD_DISC (the backend is closed first). Returns -1 when
 * reading a request header or a WRITE payload fails.
 */
int mainloop_threaded(CLIENT *client)
{
    while (1) {
        struct nbd_request req;

        int rlt = readit(client->net, &req, sizeof(struct nbd_request));
        if (rlt < 0)
            return -1;

        if (req.magic != htonl(NBD_REQUEST_MAGIC))
            err("Protocol error: not enough magic.");

        struct expand_sqe *sqe = sqe_create(&req);

        if (sqe->opcode == NBD_CMD_WRITE && sqe->len > 0) {
            /* sqe_create() leaves user_data at 0, so the payload needs its own buffer. */
            void *payload = malloc(sqe->len);
            if (!payload) {
                err("Could not allocate memory for request");
            }
            sqe->user_data = (uint64_t) (uintptr_t) payload;
            if (readit(client->net, payload, sqe->len) < 0) {
                free(payload);
                sqe->user_data = 0;
                return -1;
            }
        }
        if (sqe->opcode == NBD_CMD_DISC) {
            printf("req: magic=0x%x, type=0x%x, disconnect sock\n", ntohl(req.magic), sqe->opcode);
            client->exp_io.close(&client->exp_io);
            return 0;
        }
        handle_request(client, sqe);
    }
}

